#!/bin/bash
# Launch a program through torchrun, in single-node or multi-node mode.
#
# Usage:
#   <this script> <arguments forwarded verbatim to torchrun>
#
# Environment variables read:
#   master_addr, master_port
#       When BOTH are empty or unset, single-node mode is selected. Otherwise
#       multi-node mode is selected and both are forwarded unchanged.
#   node_rank, nnodes
#       Forwarded unchanged in multi-node mode; overwritten with 0 and 1 in
#       single-node mode.
#   nproc_per_node
#       Forwarded unchanged in multi-node mode. In single-node mode it
#       defaults to the number of lines printed by `nvidia-smi --list-gpus`
#       when empty or unset.
#
# Exit status:
#   The exit status of torchrun, so callers can detect a failed run.
set -x
start_time="$(date +%FT%T)"

launch_args=()
if [[ -z "$master_addr" && -z "$master_port" ]]; then
  # Single-node mode: these values are fixed regardless of the environment.
  nnodes=1
  node_rank=0
  master_port=12345
  # Assign the default only when the caller did not provide a value.
  : "${nproc_per_node:=$(nvidia-smi --list-gpus | wc -l)}"
else
  # Multi-node mode: the rendezvous address is only passed in this branch.
  launch_args+=("--master_addr=$master_addr")
fi

# Each flag is a single quoted array element so values are never re-split
# or glob-expanded by the shell.
launch_args+=(
  "--master_port=$master_port"
  "--node_rank=$node_rank"
  "--nproc_per_node=$nproc_per_node"
  "--nnodes=$nnodes"
)

torchrun "${launch_args[@]}" "$@"
# Capture immediately: the echo below would otherwise replace torchrun's
# status with its own (always 0) as the script's exit status.
status=$?

echo "start_time: $start_time"
exit "$status"